################################################################################
#
# Package: machinelearningtools
# Purpose: Provide convenience functions for machine learning with caret
#
################################################################################
################################################################################
# set_formula: convert target label + feature labels into a model formula
################################################################################
set_formula <- function(target_label, features) {
  features %>%
    paste(collapse = " + ") %>%
    paste(target_label, "~", .) %>%
    as.formula(env = .GlobalEnv)
}
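## Usage sketch (hypothetical labels from mtcars):
##   set_formula("mpg", c("cyl", "hp", "wt"))
##   #> mpg ~ cyl + hp + wt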
################################################################################
# get_model_metrics:
#   calculate training set performance:
#   mean & sd for all model objects in models_list
#
# set colors either by
##
##  palette:
##    models.list %>% get_model_metrics(palette = "Dark2")
##
##  color codes:
##    models.list %>% get_model_metrics(
##      colors = c("#4DAF4A", "#E41A1C", "#FF7F00", "#377EB8"))
##
##  reference: "#4DAF4A" green, "#377EB8" blue, "#E41A1C" red, "#FF7F00" orange
##
################################################################################
get_model_metrics <- function(
  models_list, target_label = NULL, testing_set = NULL,
  median_sort = FALSE, reverse = FALSE,
  palette = "Set1", colors = NULL,
  boxplot_fill = "grey95", boxplot_color = "grey25") {
  require(dplyr)
  require(RColorBrewer)
  # retrieve target.label & testing.set from models_list
  target.label <- if (!is.null(target_label)) target_label else models_list$target.label
  # set testing set to argument > from models_list > NULL if empty
  if (!is.null(testing_set)) {
    testing.set <- testing_set
  } else if (is.null(models_list$testing.set)) {
    testing.set <- NULL
  } else if (nrow(models_list$testing.set) != 0) {
    testing.set <- models_list$testing.set
  } else { # e.g. if testing.set exists but has 0 rows
    testing.set <- NULL
  }
  # remove target.label + testing.set from models.list
  if (!is.null(models_list$target.label)) {
    models_list %<>% purrr::list_modify("target.label" = NULL)
  }
  if (!is.null(models_list$testing.set)) {
    models_list %<>% purrr::list_modify("testing.set" = NULL)
  }
  target <- models_list[[1]]$trainingData$.outcome
  if (is.factor(target)) {
    metric1 = "Accuracy"
    metric2 = "Kappa"
    metric3 = NULL
    metric1.descending = FALSE
    metric2.descending = FALSE
    metric3.descending = FALSE
  } else if (is.numeric(target)) {
    metric1 = "RMSE"
    metric2 = "Rsquared"
    metric3 = "R"
    metric1.descending = TRUE
    metric2.descending = FALSE
    metric3.descending = FALSE
  }
  if (reverse) {
    metric1.descending <- !metric1.descending
    metric2.descending <- !metric2.descending
    metric3.descending <- !metric3.descending
  }
  ### get metrics from original resamples' folds
  resamples.values <- models_list %>% caret::resamples() %>% .$values %>%
    # select_if(is.numeric) %>%
    # retrieve RMSE, Rsquared but not MAE
    ## tricky: select without dplyr:: prefix does NOT work
    # dplyr::select(ends_with("RMSE"), ends_with("Rsquared"))
    dplyr::select(ends_with(metric1), ends_with(metric2)) %>%
    # calculate R from R-squared variables
    dplyr::mutate(
      across(
        .cols = ends_with(metric2), # R-squared
        .fns = sqrt,
        .names = "{.col}.R"
      )
    ) %>%
    set_names(gsub(paste0(metric2, ".R"), "R", names(.)))
  ### calculate mean and sd for each metric
  metric1.training <- get_metric_from_resamples(
    resamples.values, metric1, median_sort)
  metric2.training <- get_metric_from_resamples(
    resamples.values, metric2, median_sort)
  # only regression defines a third metric (R), so test metric3 directly
  metric3.training <- if (!is.null(metric3)) {
    get_metric_from_resamples(resamples.values, metric3, median_sort)
  } else {
    NULL
  }
  ### visualize the resampling distribution from cross-validation
  metric1.resamples.boxplots <- visualize_resamples_boxplots(
    resamples.values, metric1, palette, colors = colors,
    descending = metric1.descending,
    boxplot_fill = boxplot_fill, boxplot_color = boxplot_color)
  metric2.resamples.boxplots <- visualize_resamples_boxplots(
    resamples.values, metric2, palette, colors = colors,
    descending = metric2.descending,
    boxplot_fill = boxplot_fill, boxplot_color = boxplot_color)
  metric3.resamples.boxplots <- if (!is.null(metric3)) {
    visualize_resamples_boxplots(
      resamples.values, metric3, palette, colors = colors,
      descending = metric3.descending,
      boxplot_fill = boxplot_fill, boxplot_color = boxplot_color)
  } else {
    NULL
  }
  if (!is.null(testing.set)) {
    metrics.testing <- get_testingset_performance(
      models_list, target.label, testing.set
    )
  } else {
    metrics.testing <- NULL
  }
  if (is.factor(target)) { # classification
    benchmark.all <- merge(metric1.training, metric2.training, by = "model") %>%
      {
        if (!is.null(metrics.testing)) {
          # tricky: within conditional {} block, must reference to LHS (.)
          merge(., metrics.testing, by = "model") %>%
            arrange(desc(Acc.testing))
        } else {
          .
        }
      } %>%
      as_tibble(.)
  } else if (is.numeric(target)) { # regression
    benchmark.all <- merge(metric1.training, metric2.training, by = "model") %>%
      merge(metric3.training, by = "model") %>%
      {
        if (!is.null(metrics.testing)) {
          # tricky: within conditional {} block, must reference to LHS (.)
          merge(., metrics.testing, by = "model") %>%
            dplyr::mutate(RMSE.delta = RMSE.testing - RMSE.mean) %>%
            arrange(RMSE.testing)
        } else {
          .
        }
      } %>%
      as_tibble(.)
  }
  return(list(metric1 = metric1,
              metric2 = metric2,
              resamples.values = resamples.values,
              metric1.training = metric1.training,
              metric2.training = metric2.training,
              metric3.training = metric3.training,
              metric1.resamples.boxplots = metric1.resamples.boxplots,
              metric2.resamples.boxplots = metric2.resamples.boxplots,
              metric3.resamples.boxplots = metric3.resamples.boxplots,
              metrics.testing = metrics.testing,
              benchmark.all = benchmark.all
  ))
}
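## Usage sketch, assuming models.list was returned by benchmark_algorithms()
## below (it then carries target.label and testing.set along):
##   metrics <- models.list %>% get_model_metrics(palette = "Dark2")
##   metrics$benchmark.all               # training + testing set performance
##   metrics$metric1.resamples.boxplots  # ggplot of the resampling distribution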
################################################################################
# get_metric_from_resamples
# Helper function for get_model_metrics
################################################################################
get_metric_from_resamples <- function(
  resamples_values, metric, median_sort = FALSE) {
  require(dplyr)
  require(tidyr)  # pivot_longer
  suffix <- paste0("~", metric)
  # tricky: for arrange, convert string column name to symbol, not quosure
  # https://stackoverflow.com/a/26497839/7769076
  metric.mean <- rlang::sym(paste0(metric,".mean"))
  metric.sd <- paste0(metric,".sd")
  metric.median <- rlang::sym(paste0(metric,".median"))
  # tricky: ifelse() cannot return a symbol, so use if/else
  sort.metric <- if (median_sort) metric.median else metric.mean
  resamples_values %>%
    dplyr::select(ends_with(suffix)) %>%
    rename_with(~gsub(suffix, "", .)) %>%
    summarize(across(everything(),
                     list(median = median, mean = mean, sd = sd))) %>%
    # genius tip (.value!): https://stackoverflow.com/a/58880309/7769076
    pivot_longer(
      cols = everything(),
      names_to = c("model", ".value"),
      names_pattern =  "(.+)_(.+$)"
    ) %>%
    set_names(c(
      "model",
      as.character(metric.median),
      as.character(metric.mean),
      metric.sd
    )) %>%
    { # first columns mean+sd if not sorted by median
      if (!median_sort) {
        select(., model, ends_with("mean"), ends_with("sd"), ends_with("median"))
      } else { . }
    } %>%
    {
      if (metric == "RMSE") {
        # tricky: unquote symbol, not quosure
        # tricky: must use . inside inline dplyr code {}
        arrange(., !!sort.metric)
      } else { # for Accuracy, Kappa AND Rsquared: sort by descending order
        arrange(., desc(!!sort.metric))
      }
    }
}
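## Illustrative input/output, assuming models named "rf" and "glm":
##   resamples.values           # columns "rf~RMSE", "glm~RMSE", one row per fold
##   get_metric_from_resamples(resamples.values, "RMSE")
##   #> tibble: model, RMSE.mean, RMSE.sd, RMSE.median (ascending by RMSE.mean)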
################################################################################
# get_metric_resamples
# Helper function for tidyposterior
################################################################################
get_metric_resamples <- function(resamples_data, metric) {
  resamples_data %>%
    .$values %>%
    as_tibble() %>%
    select(Resample, contains(metric)) %>%
    # tricky: tilde (~) NOT dash (-)
    setNames(gsub(paste0("~", metric), "", names(.))) %>%
    dplyr::rename(id = Resample)
}
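## Usage sketch, assuming models.list no longer contains the
## target.label/testing.set elements:
##   resamples.data <- caret::resamples(models.list)
##   rmse.resamples <- get_metric_resamples(resamples.data, "RMSE")
##   # tibble with column id (the resample fold) plus one column per model,
##   # i.e. the wide layout consumed by tidyposterior::perf_mod()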
################################################################################
# visualize_resamples_boxplots()
# Helper function for get_model_metrics
################################################################################
visualize_resamples_boxplots <- function(
  resamples_values,
  METRIC,
  palette = "Set1",
  descending = FALSE,
  color_count = NULL,
  dot_size = NULL,
  boxplot_fill = "grey95",
  boxplot_color = "grey25",
  colors = NULL,
  exclude_light_hues = NULL
) {
  require(dplyr)
  require(tidyr)   # drop_na, pivot_longer
  require(purrr)   # set_names
  require(ggplot2)
  require(RColorBrewer)
  # dot size of resamples distribution is inversely proportional to their number
  if (is.null(dot_size)) dot_size <- 1/logb(nrow(resamples_values), 5)
  # extract the resamples values for selected METRIC (e.g. "Accuracy" or "RMSE")
  resamples.by.metric <- resamples_values %>%
    dplyr::select(ends_with(METRIC)) %>%
    purrr::set_names(~ gsub(paste0("~", METRIC), "", .)) %>%
    drop_na() %>%
    pivot_longer(
      cols = everything(),
      names_to = "model",
      values_to = METRIC,
      names_transform = list(model = as.factor)
    )
  # create HEX color codes from palette with 8+ colors
  ## Source: http://novyden.blogspot.com/2013/09/how-to-expand-color-palette-with-ggplot.html
  color.codes <- brewer.pal(8, palette)
  # remove the first color codes of palette as they have very light hues
  if (!is.null(exclude_light_hues)) {
    color.codes %<>% .[-c(1:exclude_light_hues)]
  }
  # default the number of colors to the number of columns in resamples_values
  if (is.null(color_count)) color_count <- ncol(resamples_values)
  # generate the color palette by extrapolation from color.codes to color_count
  color.palette.generated <- colorRampPalette(color.codes)(color_count)
  resamples.boxplots <- resamples.by.metric %>%
    ggplot(aes(
      {
        if (descending) {
          x = reorder(model, desc(!!sym(METRIC)), median)
        } else {
          x = reorder(model, !!sym(METRIC), median)
        }
      },
      y = !!sym(METRIC),
      color = model
    )) +
    geom_boxplot(
      width = 0.7,
      fill = boxplot_fill,
      color = boxplot_color,
      alpha = 0.3
    ) +
    geom_jitter(size = dot_size) +
    coord_flip() +
    scale_color_manual(
      values = if (!is.null(colors)) colors else color.palette.generated
    ) +
    labs(x = "model", y = METRIC) +
    theme_minimal() +
    theme(
      legend.position = "none",
      axis.title = element_text(size = 14),
      axis.text = element_text(size = 14)
    )
  return(resamples.boxplots)
}
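## Usage sketch, reusing resamples.values from get_model_metrics():
##   resamples.values %>%
##     visualize_resamples_boxplots("RMSE", palette = "Dark2", descending = TRUE)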
#######################################################################
# define string in filename
#######################################################################
logical_string <- function(logical_flag, true_string) {
  if (logical_flag) true_string else NULL
}
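## Usage sketch for building file names from flags (paste0 drops the NULL):
##   onehot <- TRUE
##   paste0("models", logical_string(onehot, ".onehot"), ".rds")
##   #> "models.onehot.rds"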
#######################################################################
# benchmark algorithms with caret::train
#######################################################################
benchmark_algorithms <- function(
  target_label,
  features_labels,
  training_set,
  testing_set,
  formula_input = FALSE,
  preprocess_configuration = c("center", "scale", "zv"),
  training_configuration,
  impute_method = NULL,
  algorithm_list,
  glm_family = NULL,
  seed = 17,
  cv_repeats,
  try_first = NULL,
  models_list_name = NULL,
  cluster_log = "",
  beep = TRUE,
  push = TRUE) {
  ########################################
  # Print the target & features
  ########################################
  target_label %>% print
  features_labels %>% print
  ########################################
  # Select the target & features
  ########################################
  target <- training_set[[target_label]]
  # avoid tibble e.g. for svmRadial: "setting rownames on tibble is deprecated"
  features <- training_set %>% select(all_of(features_labels)) %>% as.data.frame
  if (!is.null(try_first) && is.numeric(try_first)) {
    target %<>% head(try_first)
    features %<>% head(try_first)
    training_set %<>% head(try_first)
  }
  ########################################
  # Train the models
  ########################################
  models.list <- list()
  if (formula_input) {
    print("******** FORMULA interface")
    # define formula
    formula1 <- set_formula(target_label, features_labels)
    system.time(
      models.list <- algorithm_list %>%
        map(function(algorithm_label) {
          print(paste("***", algorithm_label))
          ############ START new cluster for model training
          cluster.new <- clusterOn(outfile_name = cluster_log)
          # stop cluster if training throws error (https://stackoverflow.com/a/41679580/7769076)
          on.exit(if (exists("cluster.new")) { clusterOff(cluster.new) } )
          if (algorithm_label == "rf") {
            model <- train(
              form = formula1,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = TRUE
            )
          } else if (algorithm_label == "ranger") {
            model <- train(
              form = formula1,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = "impurity"
            )
          } else if (algorithm_label == "glm" | algorithm_label == "glmnet") {
            model <- train(
              form = formula1,
              method = algorithm_label,
              family = glm_family,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          } else {
            model <- train(
              form = formula1,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          }
          ############ END model training & STOP cluster
          clusterOff(cluster.new)
          stopImplicitCluster()
          return(model)
        }) %>%
        setNames(algorithm_list)
    ) %T>% {
      if (beep) beepr::beep()
      if (push) push_message(
        time_in_seconds = .["elapsed"],
        algorithm_list = algorithm_list,
        models_list_name = models_list_name
      )
    }
    # categorical variables -> x,y interface
  } else {
    print("******** X Y INTERFACE")
    # transform categorical features by one-hot-encoding for models except rf, ranger, gbm
    # e.g. glmnet expects features as model.matrix (source: https://stackoverflow.com/a/48230658/7769076)
    if (contains_factors(training_set)) {
      formula1 <- set_formula(target_label, features_labels)
      features.onehot <- model.matrix(formula1, data = training_set) %>%
        as.data.frame() %>%
        select(-`(Intercept)`)
      # training.set.onehot <- cbind(target, features.onehot)
    }
    # backup original features before loop to avoid overriding
    features.original <- features
    # training.set.original <- training_set
    system.time(
      models.list <- algorithm_list %>%
        map(function(algorithm_label) {
          print(paste("***", algorithm_label))
          # transform factors by one-hot-encoding for all models except rf, ranger, gbm
          if (contains_factors(training_set) &&
              !handles_factors(algorithm_label) &&
              !algorithm_label %in% c("svmRadial", "svmLinear")
          ) {
            features <- features.onehot
            # training.set <- training.set.onehot
            print(paste("*** performed one-hot-encoding for model", algorithm_label))
          } else { # no onehot-encoding
            features <- features.original
            # training.set <- training.set.original
          }
          ############ START new cluster for model training
          cluster.new <- clusterOn(outfile_name = cluster_log)
          # stop cluster if training throws error (https://stackoverflow.com/a/41679580/7769076)
          on.exit(if (exists("cluster.new")) { clusterOff(cluster.new) } )
          if (algorithm_label == "rf") {
            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = TRUE
            )
          } else if (algorithm_label == "ranger") {
            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              preProcess = preprocess_configuration,
              trControl = training_configuration,
              importance = "impurity"
            )
          } else if (is.factor(target) &&
                     algorithm_label %in% c("glm", "glmnet")
          ) {
            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              family = glm_family,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          } else if (algorithm_label == "xgbTree" | algorithm_label == "xgbLinear") {
            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              nthread = 1,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          } else if (algorithm_label == "svmRadial" | algorithm_label == "svmLinear") {
            # predict() requires kernlab::ksvm object created by formula:
            # https://stackoverflow.com/q/52743663/7769076
            formula.svm <- set_formula(target_label, features_labels)
            model <- train(
              form = formula.svm,
              method = algorithm_label,
              data = training_set,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          } else {
            model <- train(
              x = features,
              y = target,
              method = algorithm_label,
              preProcess = preprocess_configuration,
              trControl = training_configuration
            )
          }
          ############ END model training & STOP cluster
          clusterOff(cluster.new)
          stopImplicitCluster()
          return(model)
        }) %>%
        setNames(algorithm_list)
    ) %T>% {
      if (beep) beepr::beep()
      if (push) push_message(
        time_in_seconds = .["elapsed"],
        algorithm_list = algorithm_list,
        models_list_name = models_list_name
      )
    }
  }
  ########################################
  # Postprocess the models
  ########################################
  # add target.label & testing.set to models.list
  models.list$target.label <- target_label
  models.list$testing.set <- testing_set
  # save the models.list
  if (!is.null(models_list_name)) {
    models.list %>% saveRDS(models_list_name)
    print(paste("model training results saved in", models_list_name))
  }
  print(models.list)
  return(models.list)
}
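## Usage sketch (illustrative labels, data split and algorithms; clusterOn(),
## clusterOff() and push_message() are helpers assumed to be defined elsewhere
## in this package):
##   models.list <- benchmark_algorithms(
##     target_label = "mpg",
##     features_labels = c("cyl", "hp", "wt"),
##     training_set = training.set,
##     testing_set = testing.set,
##     training_configuration = caret::trainControl(
##       method = "repeatedcv", number = 10, repeats = 3),
##     algorithm_list = c("lm", "glmnet", "ranger"),
##     cv_repeats = 3,
##     models_list_name = "models.mpg.rds"
##   )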
################################################################################
# Dataset contains Factors
# check if dataset contains categorical features
################################################################################
contains_factors <- function(data) {
  data %>%
    select_if(is.factor) %>%
    names %>%
    {length(.) > 0}
}
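## Example:
##   contains_factors(iris)    #> TRUE  (Species is a factor)
##   contains_factors(mtcars)  #> FALSE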
################################################################################
# Algorithm handles Factors
# Check if algorithm handles categorical features without one-hot-encoding
################################################################################
handles_factors <- function(algorithm_label) {
  # models that can handle factors instead of one-hot-encoding
  algorithms.handling.factors <- c(
    "rf", "ranger", "gbm", "nnet"
  )
  # check whether input algorithm handles factors
  algorithm_label %in% algorithms.handling.factors
}
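## Example:
##   handles_factors("ranger")  #> TRUE
##   handles_factors("glmnet")  #> FALSE -> needs one-hot-encoding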
################################################################################
# Get feature set
# From vector of feature labels, generate feature set
################################################################################
get_featureset <- function(data,
                           target_label = NULL,
                           featureset_labels = NULL,
                           select_starts = NULL) {
  data %>%
    dplyr::select(!!rlang::sym(target_label)) %>%
    {
      if (!is.null(featureset_labels)) {
        cbind(.,
              data %>%
                dplyr::select(!!!rlang::syms(featureset_labels))
        )
      } else { . }
    } %>%
    {
      if (!is.null(select_starts)) {
        cbind(.,
              map_dfc(select_starts, function(start_keyword) {
                data %>%
                  select(starts_with(start_keyword))
              })
        )
      } else { . }
    } %>%
    as_tibble()
}
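## Usage sketch (mtcars, illustrative):
##   mtcars %>% get_featureset(
##     target_label = "mpg",
##     featureset_labels = c("cyl", "hp"),
##     select_starts = "d")
##   #> tibble with columns mpg, cyl, hp, disp, drat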
################################################################################
# Get Testing Set Performance
# calculate RMSE for all model objects in model_list
################################################################################
get_testingset_performance <- function(
  models_list, target_label = NULL, testing_set = NULL) {
  # retrieve target.label + testing.set, then remove them from models_list
  if (!is.null(models_list$target.label) && !is.null(models_list$testing.set)) {
    target.label <- models_list$target.label
    testing.set <- models_list$testing.set
    models_list %<>% purrr::list_modify("target.label" = NULL, "testing.set" = NULL)
  } else if (!is.null(target_label) && !is.null(testing_set)) {
    target.label <- target_label
    testing.set <- testing_set
  } else {
    stop("provide target_label and testing_set, or store them in models_list")
  }
  features.labels <- testing.set %>% select(-all_of(target.label)) %>% names
  observed <- testing.set[[target.label]]
  # do onehot encoding for algorithms that cannot handle factors
  if (contains_factors(testing.set)) {
    formula1 <- set_formula(target.label, features.labels)
    testing.set.onehot <- model.matrix(formula1, data = testing.set) %>%
      as_tibble() %>%
      select(-`(Intercept)`)
  }
  if (is.factor(observed)) {
    models_list %>%
      map(
        function(model_object) {
          # set flag for onehot encoding
          onehot <- FALSE
          # do onehot encoding for algorithms that cannot handle factors
          if (contains_factors(testing.set) &&
              !handles_factors(model_object$method) &&
              !model_object$method %in% c("svmRadial", "svmLinear")) {
            onehot <- TRUE
          }
          model_object %>%
            # estimate target in the testing set
            predict(newdata = if (onehot) testing.set.onehot else testing.set) %>%
            confusionMatrix(., observed) %>%
            .$overall %>%
            # tricky: convert first to dataframe > can select column names
            map_df(1) %>%
            select(Accuracy, Kappa)
        }
      ) %>%
      bind_rows(.id = "model") %>%
      setNames(c("model", "Acc.testing", "Kappa.testing"))
  } else if (is.numeric(observed)) {
    mean.training.set <- models_list[[1]]$trainingData$.outcome %>% mean
    models_list %>%
      map(
        function(model_object) {
          # set flag for onehot encoding
          onehot <- FALSE
          # do onehot encoding for algorithms that cannot handle factors
          if (contains_factors(testing.set) &&
              !handles_factors(model_object$method) &&
              !model_object$method %in% c("svmRadial", "svmLinear")) {
            onehot <- TRUE
          }
          predicted <- model_object %>%
            # estimate target in the testing set
            predict(newdata = if (onehot) testing.set.onehot else testing.set)
          # tricky: name the metrics so bind_rows() can assemble one row per model
          tibble(
            RMSE.testing = sqrt(mean((observed - predicted)^2)),
            # https://stackoverflow.com/a/36727900/7769076
            R2.testing = sum((predicted - mean.training.set)^2) /
              sum((observed - mean.training.set)^2),
            # R2 = regression SS / TSS, both centered on the observed mean
            R2.testing2 = sum((predicted - mean(observed))^2) /
              sum((observed - mean(observed))^2),
            # squared correlation, as reported by caret::postResample
            R2.postResample = cor(predicted, observed)^2
          )
        }) %>%
      bind_rows(.id = "model") %>%
      arrange(RMSE.testing) %>%
      as.data.frame
  }
}
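## Usage sketch: pass target label and testing set explicitly, or rely on the
## elements stored in models.list by benchmark_algorithms():
##   models.list %>% get_testingset_performance()
##   # classification -> model, Acc.testing, Kappa.testing
##   # regression     -> model, RMSE.testing, R2.testing, ...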
################################################################################
# Visualize variable importance
# input caret::train object
################################################################################
visualize_importance <- function (
  model_object, # caret::train object
  relative = FALSE, # calculate relative importance scores (not normalized)
  axis_label = NULL, # label for vertical axis
  axis_tick_labels = NULL, # labels for items/facets/factors
  text_labels = FALSE, # labels showing numeric scores next to bar
  axis_limit = NULL, # max. axis score displayed
  width = 4, height = 3, dpi = 300, # specs for saved plot
  fill_color = "#114151",
  font_size = 10,
  save_label = "" # filename for saved plot
) {
  require(caret)
  require(gbm)
  require(dplyr)
  require(tibble)  # rownames_to_column
  require(ggplot2)
  # calculate feature importance
  importance_object <- model_object %>% caret::varImp()
  unit.label <- ifelse(relative, "%RI", "importance")
  unit.variable <- rlang::sym(unit.label)
  if (inherits(importance_object, "varImp.train")) {
    importance_object %<>% .$importance
  }
  if (!hasName(importance_object, "rowname")) {
    importance_object %<>% rownames_to_column()
  }
  importance.table <- importance_object %>%
    dplyr::rename(variable = rowname, importance = Overall) %>%
    arrange(desc(importance)) %>%
    {
      if (relative) {
        dplyr::mutate(., `%RI` = importance/sum(importance)*100) %>%
          select(variable, `%RI`)
      } else {
        .
      }
    }
  importance.plot <- importance.table %>%
    set_names(c("variable", unit.label)) %>%
    ggplot(data = .,
           aes(x = reorder(variable, !!unit.variable), y = !!unit.variable)) +
    theme_minimal() +
    geom_bar(stat = "identity", fill = fill_color) +
    {
      if (text_labels) {
        geom_text(aes(label = round(!!unit.variable, digits = 2)),
                  position = position_dodge(width = 5),
                  hjust = -0.1,
                  check_overlap = TRUE,
                  # tricky: font size must be scaled down by ggplot2:::.pt
                  # https://stackoverflow.com/a/17312440/7769076
                  size = font_size / (ggplot2:::.pt * 1.1)
        )
      }
    } +
    coord_flip() +
    theme(axis.title = element_text(size = font_size),
          axis.text = element_text(size = font_size)) +
    {
      if (!is.null(axis_limit)) {
        scale_y_continuous(expand = c(0, 0),
                           limits = c(0, axis_limit))
      }
    } +
    {
      if (!is.null(axis_tick_labels)) {
        scale_x_discrete(labels = axis_tick_labels)
      }
    } +
    labs(
      x = axis_label,
      y = unit.label
    )
  if (save_label != "") {
    ggsave(
      filename = save_label,
      plot = importance.plot,
      dpi = dpi,
      width = width,
      height = height
    )
  }
  return(
    list(
      importance.table = importance.table,
      importance.plot = importance.plot
    ))
}
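## Usage sketch on a single trained model (assumes ranger was trained with
## importance = "impurity", as benchmark_algorithms() does):
##   importance <- models.list$ranger %>%
##     visualize_importance(relative = TRUE, text_labels = TRUE)
##   importance$importance.plot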